{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# Concise Implementation for Multiple GPUs\n",
    ":label:`sec_multi_gpu_concise`\n",
    "\n",
    "Implementing parallelism from scratch for every new model is no fun. Moreover, there is significant benefit in optimizing synchronization tools for high performance. In the following we will show how to do this using DJL. The math and the algorithms are the same as in :numref:`sec_multi_gpu`. As before we begin by importing the required modules (quite unsurprisingly you will need at least two GPUs to run this notebook).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/plot-utils\n",
    "%load ../utils/Training.java"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ai.djl.basicdataset.cv.classification.*;\n",
    "import ai.djl.metric.*;\n",
    "import org.apache.commons.lang3.ArrayUtils;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 2
   },
   "source": [
    "## A Toy Network\n",
    "\n",
    "Let us use a slightly more meaningful network than LeNet from the previous section that's still sufficiently easy and quick to train. We pick a ResNet-18 variant :cite:`He.Zhang.Ren.ea.2016`. Since the input images are tiny we modify it slightly. In particular, the difference to :numref:`sec_resnet` is that we use a smaller convolution kernel, stride, and padding at the beginning. Moreover, we remove the max-pooling layer.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Residual extends AbstractBlock {\n",
    "\n",
    "    private static final byte VERSION = 2;\n",
    "\n",
    "    public ParallelBlock block;\n",
    "\n",
    "    public Residual(int numChannels, boolean use1x1Conv, Shape strideShape) {\n",
    "        super(VERSION);\n",
    "\n",
    "        SequentialBlock b1;\n",
    "        SequentialBlock conv1x1;\n",
    "\n",
    "        b1 = new SequentialBlock();\n",
    "\n",
    "        b1.add(Conv2d.builder()\n",
    "                .setFilters(numChannels)\n",
    "                .setKernelShape(new Shape(3, 3))\n",
    "                .optPadding(new Shape(1, 1))\n",
    "                .optStride(strideShape)\n",
    "                .build())\n",
    "                .add(BatchNorm.builder().build())\n",
    "                .add(Activation::relu)\n",
    "                .add(Conv2d.builder()\n",
    "                        .setFilters(numChannels)\n",
    "                        .setKernelShape(new Shape(3, 3))\n",
    "                        .optPadding(new Shape(1, 1))\n",
    "                        .build())\n",
    "                .add(BatchNorm.builder().build());\n",
    "\n",
    "        if (use1x1Conv) {\n",
    "            conv1x1 = new SequentialBlock();\n",
    "            conv1x1.add(Conv2d.builder()\n",
    "                    .setFilters(numChannels)\n",
    "                    .setKernelShape(new Shape(1, 1))\n",
    "                    .optStride(strideShape)\n",
    "                    .build());\n",
    "        } else {\n",
    "            conv1x1 = new SequentialBlock();\n",
    "            conv1x1.add(Blocks.identityBlock());\n",
    "        }\n",
    "\n",
    "        block = addChildBlock(\"residualBlock\", new ParallelBlock(\n",
    "                list -> {\n",
    "                    NDList unit = list.get(0);\n",
    "                    NDList parallel = list.get(1);\n",
    "                    return new NDList(\n",
    "                            unit.singletonOrThrow()\n",
    "                                    .add(parallel.singletonOrThrow())\n",
    "                                    .getNDArrayInternal()\n",
    "                                    .relu());\n",
    "                },\n",
    "                Arrays.asList(b1, conv1x1)));\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    protected NDList forwardInternal(\n",
    "            ParameterStore parameterStore,\n",
    "            NDList inputs,\n",
    "            boolean training,\n",
    "            PairList<String, Object> params) {\n",
    "        return block.forward(parameterStore, inputs, training);\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    public Shape[] getOutputShapes(Shape[] inputs) {\n",
    "        Shape[] current = inputs;\n",
    "        for (Block block : block.getChildren().values()) {\n",
    "            current = block.getOutputShapes(current);\n",
    "        }\n",
    "        return current;\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    protected void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {\n",
    "        block.initialize(manager, dataType, inputShapes);\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "2"
    },
    "origin_pos": 3,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "public SequentialBlock resnetBlock(int numChannels, int numResiduals, boolean isFirstBlock) {\n",
    "\n",
    "        SequentialBlock blk = new SequentialBlock();\n",
    "        for (int i = 0; i < numResiduals; i++) {\n",
    "\n",
    "            if (i == 0 && !isFirstBlock) {\n",
    "                blk.add(new Residual(numChannels, true, new Shape(2, 2)));\n",
    "            } else {\n",
    "                blk.add(new Residual(numChannels, false, new Shape(1, 1)));\n",
    "            }\n",
    "        }\n",
    "        return blk;\n",
    "}\n",
    "\n",
    "int numClass = 10;\n",
    "// This model uses a smaller convolution kernel, stride, and padding and\n",
    "// removes the maximum pooling layer\n",
    "SequentialBlock net = new SequentialBlock();\n",
    "net\n",
    "    .add(\n",
    "            Conv2d.builder()\n",
    "                    .setFilters(64)\n",
    "                    .setKernelShape(new Shape(3, 3))\n",
    "                    .optPadding(new Shape(1, 1))\n",
    "                    .build())\n",
    "    .add(BatchNorm.builder().build())\n",
    "    .add(Activation::relu)\n",
    "    .add(resnetBlock(64, 2, true))\n",
    "    .add(resnetBlock(128, 2, false))\n",
    "    .add(resnetBlock(256, 2, false))\n",
    "    .add(resnetBlock(512, 2, false))\n",
    "    .add(Pool.globalAvgPool2dBlock())\n",
    "    .add(Linear.builder().setUnits(numClass).build());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 4
   },
   "source": [
    "## Parameter Initialization and Logistics\n",
    "\n",
    "The `setInitializer` method allows us to set initial defaults for parameters on a device of our choice. For a refresher see :numref:`sec_numerical_stability`. What is particularly convenient is that it also lets us initialize the network on *multiple* devices simultaneously. Let us try how this works in practice.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "3"
    },
    "origin_pos": 5,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "Model model = Model.newInstance(\"training-multiple-gpus-1\");\n",
    "model.setBlock(net);\n",
    "\n",
    "Loss loss = Loss.softmaxCrossEntropyLoss();\n",
    "\n",
    "Tracker lrt = Tracker.fixed(0.1f);\n",
    "Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();\n",
    "\n",
    "DefaultTrainingConfig config = new DefaultTrainingConfig(loss)\n",
    "        .optOptimizer(sgd) // Optimizer (loss function)\n",
    "        .optInitializer(new NormalInitializer(0.01f), Parameter.Type.WEIGHT) // setting the initializer\n",
    "        .optDevices(Device.getDevices(1)) // setting the number of GPUs needed\n",
    "        .addEvaluator(new Accuracy()) // Model Accuracy\n",
    "        .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging\n",
    "\n",
    "Trainer trainer = model.newTrainer(config);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 6
   },
   "source": [
    "Using the `split` function in the previous section we can divide a minibatch of data and copy portions to your list of devices, and this list can be retrieved from `Device` variable. We can then split the data with the help of `Batchifier` class. The network object *automatically* uses the appropriate GPU to compute the value of the forward propagation. As before we generate 4 observations and split them over the GPUs.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "4"
    },
    "origin_pos": 7,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "NDManager manager = NDManager.newBaseManager();\n",
    "NDArray X = manager.randomUniform(0f, 1.0f, new Shape(4, 1, 28, 28));\n",
    "trainer.initialize(X.getShape());\n",
    "\n",
    "NDList[] res = Batchifier.STACK.split(new NDList(X), 4, true);\n",
    "\n",
    "ParameterStore parameterStore = new ParameterStore(manager, true);\n",
    "\n",
    "System.out.println(net.forward(parameterStore, new NDList(res[0]), false).singletonOrThrow());\n",
    "System.out.println(net.forward(parameterStore, new NDList(res[1]), false).singletonOrThrow());\n",
    "System.out.println(net.forward(parameterStore, new NDList(res[2]), false).singletonOrThrow());\n",
    "System.out.println(net.forward(parameterStore, new NDList(res[3]), false).singletonOrThrow());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 8
   },
   "source": [
    "Once data passes through the network, the corresponding parameters are initialized *on the device the data passed through*. This means that initialization happens on a per-device basis.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "5"
    },
    "origin_pos": 9,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "net.getChildren().values().get(0).getParameters().get(\"weight\").getArray().get(new NDIndex(\"0:1\"));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 12
   },
   "source": [
    "## Training\n",
    "\n",
    "As before, the training code needs to perform a number of basic functions for efficient parallelism:\n",
    "\n",
    "* Network parameters need to be initialized across all devices.\n",
    "* While iterating over the dataset minibatches are to be divided across all devices.\n",
    "* We compute the loss and its gradient in parallel across devices.\n",
    "* Losses are aggregated (by the trainer method) and parameters are updated accordingly.\n",
    "\n",
    "In the end we compute the accuracy (again in parallel) to report the final value of the network. The training routine is quite similar to implementations in previous chapters, except that we need to split and aggregate data.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "int numEpochs = Integer.getInteger(\"MAX_EPOCH\", 10);\n",
    "\n",
    "double[] testAccuracy;\n",
    "double[] epochCount;\n",
    "\n",
    "epochCount = new double[numEpochs];\n",
    "\n",
    "for (int i = 0; i < epochCount.length; i++) {\n",
    "    epochCount[i] = (i + 1);\n",
    "}\n",
    "\n",
    "Map<String, double[]> evaluatorMetrics = new HashMap<>();\n",
    "double avgTrainTimePerEpoch = 0;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "7"
    },
    "origin_pos": 13,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "public void train(int numEpochs, Trainer trainer, int batchSize) throws IOException, TranslateException {\n",
    "\n",
    "    FashionMnist trainIter = FashionMnist.builder()\n",
    "            .optUsage(Dataset.Usage.TRAIN)\n",
    "            .setSampling(batchSize, true)\n",
    "            .optLimit(Long.getLong(\"DATASET_LIMIT\", Long.MAX_VALUE))\n",
    "            .build();\n",
    "    FashionMnist testIter = FashionMnist.builder()\n",
    "            .optUsage(Dataset.Usage.TEST)\n",
    "            .setSampling(batchSize, true)\n",
    "            .optLimit(Long.getLong(\"DATASET_LIMIT\", Long.MAX_VALUE))\n",
    "            .build();\n",
    "\n",
    "    trainIter.prepare();\n",
    "    testIter.prepare();\n",
    "\n",
    "    Map<String, double[]> evaluatorMetrics = new HashMap<>();\n",
    "    double avgTrainTime = 0;\n",
    "\n",
    "    trainer.setMetrics(new Metrics());\n",
    "\n",
    "    EasyTrain.fit(trainer, numEpochs, trainIter, testIter);\n",
    "\n",
    "    Metrics metrics = trainer.getMetrics();\n",
    "\n",
    "    trainer.getEvaluators().stream()\n",
    "            .forEach(evaluator -> {\n",
    "                evaluatorMetrics.put(\"validate_epoch_\" + evaluator.getName(), metrics.getMetric(\"validate_epoch_\" + evaluator.getName()).stream()\n",
    "                        .mapToDouble(x -> x.getValue().doubleValue()).toArray());\n",
    "            });\n",
    "\n",
    "    avgTrainTime = metrics.mean(\"epoch\");\n",
    "    testAccuracy = evaluatorMetrics.get(\"validate_epoch_Accuracy\");\n",
    "    System.out.printf(\"test acc %.2f\\n\" , testAccuracy[numEpochs-1]);\n",
    "    System.out.println(avgTrainTime / Math.pow(10, 9) + \" sec/epoch \\n\");\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 14
   },
   "source": [
    "## Experiments\n",
    "\n",
    "Let us see how this works in practice. As a warmup we train the network on a single GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "8"
    },
    "origin_pos": 15,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "Table data = null;\n",
    "// We will check if we have atleast 1 GPU available. If yes, we run the training on 1 GPU.\n",
    "if (Device.getGpuCount() >= 1) {\n",
    "    train(numEpochs, trainer, 256);\n",
    "\n",
    "    data = Table.create(\"Data\");\n",
    "    data = data.addColumns(\n",
    "            DoubleColumn.create(\"X\", epochCount), \n",
    "            DoubleColumn.create(\"testAccuracy\", testAccuracy)\n",
    "    );\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// uncomment to view graph if you have 1 GPU, since the render function doesn't work inside the if condition scope.\n",
    "// render(LinePlot.create(\"\", data, \"x\", \"testAccuracy\"), \"text/html\");    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Contour Gradient Descent.](https://d2l-java-resources.s3.amazonaws.com/img/training-with-1-gpu.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Table data = Table.create(\"Data\");\n",
    "\n",
    "// We will check if we have more than 1 GPU available. If yes, we run the training on 2 GPU.\n",
    "if (Device.getGpuCount() > 1) {\n",
    "\n",
    "    X = manager.randomUniform(0f, 1.0f, new Shape(1, 1, 28, 28));\n",
    "\n",
    "    Model model = Model.newInstance(\"training-multiple-gpus-2\");\n",
    "    model.setBlock(net);\n",
    "\n",
    "    loss = Loss.softmaxCrossEntropyLoss();\n",
    "\n",
    "    Tracker lrt = Tracker.fixed(0.2f);\n",
    "    Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();\n",
    "\n",
    "    DefaultTrainingConfig config = new DefaultTrainingConfig(loss)\n",
    "                .optOptimizer(sgd) // Optimizer (loss function)\n",
    "                .optInitializer(new NormalInitializer(0.01f), Parameter.Type.WEIGHT) // setting the initializer\n",
    "                .optDevices(Device.getDevices(2)) // setting the number of GPUs needed\n",
    "                .addEvaluator(new Accuracy()) // Model Accuracy\n",
    "                .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging\n",
    "\n",
    "    Trainer trainer = model.newTrainer(config);\n",
    "    \n",
    "    trainer.initialize(X.getShape());\n",
    "\n",
    "    Map<String, double[]> evaluatorMetrics = new HashMap<>();\n",
    "    double avgTrainTimePerEpoch = 0;\n",
    "\n",
    "    train(numEpochs, trainer, 512);\n",
    "    \n",
    "    data = data.addColumns(\n",
    "        DoubleColumn.create(\"X\", epochCount), \n",
    "        DoubleColumn.create(\"testAccuracy\", testAccuracy)\n",
    "    );\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// uncomment to view graph if you have 2 GPU, since the render function doesn't work inside the if condition scope.\n",
    "// render(LinePlot.create(\"\", data, \"x\", \"testAccuracy\"), \"text/html\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Contour Gradient Descent.](https://d2l-java-resources.s3.amazonaws.com/img/training-with-2-gpu.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "* Data is automatically evaluated on the devices where the data can be found.\n",
    "* Take care to initialize the networks on each device before trying to access the parameters on that device. Otherwise you will encounter an error.\n",
    "* The optimization algorithms automatically aggregate over multiple GPUs.\n",
    "\n",
    "## Exercises\n",
    "\n",
    "1. This section uses ResNet-18. Try different epochs, batch sizes, and learning rates. Use more GPUs for computation. What happens if you try this on a p2.16xlarge instance with 16 GPUs?\n",
    "2. Sometimes, different devices provide different computing power. We could use the GPUs and the CPU at the same time. How should we divide the work? Is it worth the effort? Why? Why not?\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
